XLAShardedTensor.to_local() support #9505
base: master
Conversation
f0c89b9 to 933a964 (Compare)
@classmethod
def setUpClass(cls):
    super().setUpClass()
I don't think this method is necessary if there's no additional setup logic
I added this to maintain consistency with other tests in the spmd folder, but happy to remove it if you think simplicity is more important than consistency here.
# All gradients should be 1.0 since we did a sum()
self.assertTrue(torch.allclose(local_tensor.grad, torch.ones_like(tensor)))

print("Gradient flow test successful")
remove
Done!
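
As a side note on the expected values in that assertion: for y = x.sum(), every element of x contributes to y with coefficient 1, so the gradient is all ones. A standalone illustration in plain torch, independent of any sharding:

import torch

x = torch.randn(4, 4, requires_grad=True)
y = x.sum()    # y is the sum of all elements, so dy/dx_ij == 1 for each entry
y.backward()
assert torch.equal(x.grad, torch.ones_like(x))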
…ure proper XLA support and maintain consistency with PyTorch/XLA SPMD integration.
812a69a to 6dc2351 (Compare)
The implementation adds a to_local() method to the XLAShardedTensor class that converts a sharded tensor back to its local representation while preserving gradient information. The accompanying tests verify that:
- the requires_grad property is preserved
- gradients are correctly computed and maintained
- the backward pass works through the local tensor
- gradient values are accurately preserved
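
For readers skimming the thread, a minimal usage sketch of what the new method enables. The mesh shape, the partition spec, and the assumption that mark_sharding yields an XLAShardedTensor are illustrative, not taken from this PR:

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()
device = xm.xla_device()
num_devices = xr.global_runtime_device_count()
# Illustrative 1-D mesh over all devices (assumed setup, not from this PR)
mesh = xs.Mesh(np.arange(num_devices), (num_devices,), ('data',))

tensor = torch.randn(8, 8, requires_grad=True, device=device)
# Assumed here to yield an XLAShardedTensor, matching how the PR's tests obtain one
sharded = xs.mark_sharding(tensor, mesh, ('data', None))

local_tensor = sharded.to_local()   # new method added by this PR
assert local_tensor.requires_grad   # requires_grad is preserved

local_tensor.sum().backward()
# Per the gradient-flow test above, the gradient of sum() is all ones
assert torch.allclose(local_tensor.grad, torch.ones_like(tensor))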